import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType
from datasets import load_dataset
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from torch.utils.data import ConcatDataset
from typing import List
import torch
import random
import numpy as np
import evaluate
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration
from torch.utils.data import DataLoader, random_split
from typing import List
import copy
import argparse
import re
import shutil

import torch.nn.functional as F

from scipy import spatial
from statistics import mean
print("started")

# Diagnostic accumulators. They are only initialised here; whatever fills
# them lives outside this section of the file.
w_cos, w_var, w_varp = [], [], []
g_cos, g_var, g_varp = [], [], []
wg_cos, wg_var, wg_varp = [], [], []
cos_wgrad = []
avg_acc = []

# Command-line interface. Every flag is optional, so an omitted flag shows up
# as ``None`` on ``args``.
ap = argparse.ArgumentParser()
ap.add_argument("--num_virtual_tokens", required=False, type=int,
                help="number of virtual tokens in prefix encoding layer")
ap.add_argument("--lr", required=False, type=float,
                help="learning rate of centralized model")
ap.add_argument("--batch_size", required=False, type=int, help="batch_size")
ap.add_argument("--local_epoch", required=False, type=int,
                help="number of local epoch for clients")
ap.add_argument(
    "--num_rounds",
    required=False,
    type=int,
    help="number of training rounds")
ap.add_argument(
    "--mu",
    required=False,
    type=float,
    help="hyperparameter in FedProx")
args = ap.parse_args()

os.environ["TOKENIZERS_PARALLELISM"] = "false"


device = "cuda:0"
print("device: ", device)
model_name_or_path = "t5-large"
tokenizer_name_or_path = "t5-large"
max_length = 512

# NOTE: in the code visible here only --lr is read from the command line;
# --num_virtual_tokens, --batch_size, --local_epoch, --num_rounds and --mu
# are parsed but the fixed values below are what is actually used.
num_virtual_tokens = 20  # default
# Fall back to 0.01 when --lr is not given instead of propagating None.
lr = args.lr if args.lr is not None else 0.01
batch_size = 32  # default

num_rounds = 50  # default
local_epochs = 1
mu = 0.001

# Placeholder seed (the original line was an unfilled template and did not
# parse); replace with the seed you want to reproduce.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


def Remove_Columns(example):
    """Remove the ``idx`` entry from ``example`` in place and return ``example``."""
    example.pop("idx")
    return example
def remove_syn_col(ex):
    """Delete the ``score`` and ``start_prompt`` entries of ``ex`` and return it."""
    for column in ('score', 'start_prompt'):
        del ex[column]
    return ex
def convert_labels_to_int(example):
    """Cast ``example['label']`` to ``int`` in place and return ``example``."""
    numeric_label = int(example['label'])
    example['label'] = numeric_label
    return example

# ---- Synthetic data: CoLA ----
# Pipeline: load JSON -> cast labels to int -> add a "text_label" column ->
# drop score/start_prompt -> rename 'text' to 'sentence'.
cola_path = 'synthetic-data/CoLA/train.json'
cola_syn_classes = ['unacceptable', 'acceptable']
cola_syn = (
    load_dataset('json', data_files=cola_path)
    .map(convert_labels_to_int)
    .map(
        lambda batch: {"text_label": [cola_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col, batched=True, num_proc=1)
    .rename_column('text', 'sentence')
)
print('syn2', cola_syn['train'][0])

# ---- Synthetic data: SST-2 (same pipeline as above) ----
sst_path = 'synthetic-data/SST-2/train.json'
sst_syn_classes = ['negative', 'positive']
sst_syn = (
    load_dataset('json', data_files=sst_path)
    .map(convert_labels_to_int)
    .map(
        lambda batch: {"text_label": [sst_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col, batched=True, num_proc=1)
    .rename_column('text', 'sentence')
)
#qqp
def remove_syn_col2(ex):
    """Delete ``score``, ``start_prompt`` and ``conj_prompt`` from ``ex`` and return it."""
    for column in ('score', 'start_prompt', 'conj_prompt'):
        del ex[column]
    return ex

# ---- Synthetic data: QQP ----
# Pipeline: load JSON -> int labels -> add "text_label" -> drop
# score/start_prompt/conj_prompt -> rename text1/text2 to question1/question2.
qqp_path = 'synthetic-data/QQP/train.json'
qqp_syn_classes = ['not_duplicate', 'duplicate']
qqp_syn = (
    load_dataset('json', data_files=qqp_path)
    .map(convert_labels_to_int)
    .map(
        lambda batch: {"text_label": [qqp_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col2, batched=True, num_proc=1)
    .rename_column('text1', 'question1')
    .rename_column('text2', 'question2')
)

# ---- Synthetic data: MRPC (text1/text2 become sentence1/sentence2) ----
mrpc_path = 'synthetic-data/MRPC/train.json'
mrpc_syn_classes = ['not_equivalent', 'equivalent']
mrpc_syn = (
    load_dataset('json', data_files=mrpc_path)
    .map(convert_labels_to_int)
    .map(
        lambda batch: {"text_label": [mrpc_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col2, batched=True, num_proc=1)
    .rename_column('text1', 'sentence1')
    .rename_column('text2', 'sentence2')
)

# ---- MNLI: string label -> integer id ----
label_mapping = {
    name: index
    for index, name in enumerate(('entailment', 'neutral', 'contradiction'))
}


def convert_labels_to_int_mnli(example):
    """Replace ``example['label']`` with its integer id from ``label_mapping``.

    Raises ``KeyError`` if the label is not a key of ``label_mapping``.
    """
    label_name = example['label']
    example['label'] = label_mapping[label_name]
    return example

# ---- Synthetic data: MNLI ----
# Pipeline: load JSON -> map string labels to ids -> add "text_label" ->
# drop score/start_prompt/conj_prompt -> rename text1/text2 to
# premise/hypothesis.
mnli_path = 'synthetic-data/MNLI/train.json'
mnli_syn_classes = ['entailment', 'neutral', 'contradiction']
mnli_syn = (
    load_dataset('json', data_files=mnli_path)
    .map(convert_labels_to_int_mnli)
    .map(
        lambda batch: {"text_label": [mnli_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col2, batched=True, num_proc=1)
    .rename_column('text1', 'premise')
    .rename_column('text2', 'hypothesis')
)

# ---- QNLI / RTE: string label -> integer id ----
label_mapping2 = {'entailment': 0, 'not_entailment': 1}


def convert_labels_to_int2(example):
    """Replace ``example['label']`` with its integer id from ``label_mapping2``.

    Raises ``KeyError`` if the label is not a key of ``label_mapping2``.
    """
    label_name = example['label']
    example['label'] = label_mapping2[label_name]
    return example

# ---- Synthetic data: QNLI (text1/text2 become question/sentence) ----
qnli_path = 'synthetic-data/QNLI/train.json'
qnli_syn_classes = ['entailment', 'not_entailment']
qnli_syn = (
    load_dataset('json', data_files=qnli_path)
    .map(convert_labels_to_int2)
    .map(
        lambda batch: {"text_label": [qnli_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col2, batched=True, num_proc=1)
    .rename_column('text1', 'question')
    .rename_column('text2', 'sentence')
)

# ---- Synthetic data: RTE (text1/text2 become sentence1/sentence2) ----
rte_path = 'synthetic-data/RTE/train.json'
rte_syn_classes = ['entailment', 'not_entailment']
rte_syn = (
    load_dataset('json', data_files=rte_path)
    .map(convert_labels_to_int2)
    .map(
        lambda batch: {"text_label": [rte_syn_classes[i] for i in batch["label"]]},
        batched=True,
        num_proc=1,
    )
    .map(remove_syn_col2, batched=True, num_proc=1)
    .rename_column('text1', 'sentence1')
    .rename_column('text2', 'sentence2')
)

# ---- WNLI label ids ----
# Only the mapping is defined here; no converter function or WNLI dataset
# follows in this part of the file.
label_mapping3 = {'entailment': 1, 'not entailment': 0}

"""COLA dataset"""
print("loading")
cola_dataset = load_dataset("glue", "cola", cache_dir="./data")
print("loading done")

cola_classes = cola_dataset["train"].features["label"].names
print('cola_dataset', cola_dataset['train'][0])

# Add a human-readable "text_label" column, then drop the "idx" column.
cola_dataset = cola_dataset.map(
    lambda batch: {"text_label": [cola_classes[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)

# Tokenizer shared by every preprocess_function_* below (t5-large).
tokenizer = T5Tokenizer.from_pretrained('t5-large', model_max_length=512)

def preprocess_function_cola(examples):
    """Tokenize a batch of CoLA examples into model inputs and label ids.

    Each sentence is prefixed with a task prompt. The target string is '1'
    for a truthy ``label`` and '2' otherwise. Pad-token positions in the
    tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    task_prompt = "Sentence acceptability analysis: "
    prompts = [task_prompt + sentence for sentence in examples["sentence"]]
    answer_strings = ['1' if flag else '2' for flag in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- CoLA: tokenize every split and build the train/eval loaders ----
cola_processed_datasets = cola_dataset.map(
    preprocess_function_cola, batched=True, num_proc=1,
    remove_columns=cola_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

cola_train_dataset = cola_processed_datasets["train"]
cola_eval_dataset = cola_processed_datasets["validation"]
c = 0

cola_train_dataloader = DataLoader(
    cola_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
print("len train loader:", len(cola_train_dataloader))

cola_eval_dataloader = DataLoader(
    cola_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)

cola_metric = evaluate.load('glue', "cola")

# ---- CoLA synthetic data: same tokenization, single shuffled loader ----
cola_syn_processed_datasets = cola_syn.map(
    preprocess_function_cola, batched=True, num_proc=1,
    remove_columns=cola_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

cola_syn_dataset = cola_syn_processed_datasets["train"]
c = 0

cola_syn_dataloader = DataLoader(
    cola_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)


"""SST-2 Dataset"""
sst_dataset = load_dataset("glue", "sst2", cache_dir="./data")
del sst_dataset["test"]
classes_sst = sst_dataset["train"].features["label"].names
print('sst', sst_dataset["train"][0])
print(classes_sst)

# Add "text_label" (label name per example), then drop the "idx" column.
sst_dataset = sst_dataset.map(
    lambda batch: {"text_label": [classes_sst[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)


def preprocess_function_sst(examples):
    """Tokenize a batch of SST-2 examples into model inputs and label ids.

    Each sentence is prefixed with a task prompt. The target string is '1'
    when ``text_label`` equals 'positive' and '2' otherwise. Pad-token
    positions in the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    task_prompt = "Movie review sentiment analysis: How do you feel about "
    prompts = [task_prompt + sentence for sentence in examples["sentence"]]
    answer_strings = [
        '1' if name == 'positive' else '2' for name in examples["text_label"]
    ]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- SST-2: tokenize every split and build the train/eval loaders ----
sst_processed_datasets = sst_dataset.map(
    preprocess_function_sst, batched=True, num_proc=1,
    remove_columns=sst_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

sst_train_dataset = sst_processed_datasets["train"]
sst_eval_dataset = sst_processed_datasets["validation"]

sst_train_dataloader = DataLoader(
    sst_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
sst_eval_dataloader = DataLoader(
    sst_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)

# ---- SST-2 synthetic data: same tokenization, single shuffled loader ----
sst_syn_processed_datasets = sst_syn.map(
    preprocess_function_sst, batched=True, num_proc=1,
    remove_columns=sst_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

sst_syn_dataset = sst_syn_processed_datasets["train"]

sst_syn_dataloader = DataLoader(
    sst_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)


# ---- MRPC: load GLUE data, drop the test split ----
mrpc_dataset = load_dataset("glue", "mrpc", cache_dir="./data")
del mrpc_dataset["test"]
print('mrpc', mrpc_dataset["train"][0])
classes = mrpc_dataset["train"].features["label"].names
print(classes)

# Add "text_label" (label name per example), then drop the "idx" column.
mrpc_dataset = mrpc_dataset.map(
    lambda batch: {"text_label": [classes[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)


def preprocess_function_mrpc(examples):
    """Tokenize a batch of MRPC sentence pairs into model inputs and label ids.

    Both sentences are embedded in a single prompt string. The target string
    is '1' for a truthy ``label`` and '2' otherwise. Pad-token positions in
    the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    sentence_pairs = zip(examples['sentence1'], examples['sentence2'])
    prompts = [
        f"equivalent: mrpc sentence1: {text1} sentence2: {text2}"
        for text1, text2 in sentence_pairs
    ]
    answer_strings = ['1' if flag else '2' for flag in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- MRPC: tokenize every split and build the train/eval loaders ----
mrpc_processed_datasets = mrpc_dataset.map(
    preprocess_function_mrpc, batched=True, num_proc=1,
    remove_columns=mrpc_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

mrpc_train_dataset = mrpc_processed_datasets["train"]
mrpc_eval_dataset = mrpc_processed_datasets["validation"]

mrpc_train_dataloader = DataLoader(
    mrpc_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
mrpc_eval_dataloader = DataLoader(
    mrpc_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)
mrpc_metric = evaluate.load('glue', "mrpc")

# ---- MRPC synthetic data: same tokenization, single shuffled loader ----
mrpc_syn_processed_datasets = mrpc_syn.map(
    preprocess_function_mrpc, batched=True, num_proc=1,
    remove_columns=mrpc_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

mrpc_syn_dataset = mrpc_syn_processed_datasets["train"]

mrpc_syn_dataloader = DataLoader(
    mrpc_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)

# ---- QQP: load GLUE data, drop the test split ----
qqp_dataset = load_dataset("glue", "qqp", cache_dir="./data")
del qqp_dataset["test"]
print('qqp', qqp_dataset["train"][0])
qqp_classes = qqp_dataset["train"].features["label"].names
print(qqp_classes)

# Add "text_label" (label name per example), then drop the "idx" column.
qqp_dataset = qqp_dataset.map(
    lambda batch: {"text_label": [qqp_classes[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)


def preprocess_function_qqp(examples):
    """Tokenize a batch of QQP question pairs into model inputs and label ids.

    Both questions are embedded in a single prompt string. The target string
    is '1' for a truthy ``label`` and '2' otherwise. Pad-token positions in
    the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    question_pairs = zip(examples['question1'], examples['question2'])
    prompts = [
        f"duplicate: qqp sentence1: {text1} sentence2: {text2}"
        for text1, text2 in question_pairs
    ]
    answer_strings = ['1' if flag else '2' for flag in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- QQP: tokenize every split and build the train/eval loaders ----
qqp_processed_datasets = qqp_dataset.map(
    preprocess_function_qqp, batched=True, num_proc=1,
    remove_columns=qqp_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

qqp_train_dataset = qqp_processed_datasets["train"]
qqp_eval_dataset = qqp_processed_datasets["validation"]

qqp_train_dataloader = DataLoader(
    qqp_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
qqp_eval_dataloader = DataLoader(
    qqp_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)
qqp_metric = evaluate.load('glue', "qqp")

# ---- QQP synthetic data: same tokenization, single shuffled loader ----
qqp_syn_processed_datasets = qqp_syn.map(
    preprocess_function_qqp, batched=True, num_proc=1,
    remove_columns=qqp_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

qqp_syn_dataset = qqp_syn_processed_datasets["train"]

qqp_syn_dataloader = DataLoader(
    qqp_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)


# ---- MNLI: load GLUE data, keep only train and validation_matched ----
mnli_dataset = load_dataset("glue", "mnli", cache_dir="./data")
for unused_split in ("validation_mismatched", "test_matched", "test_mismatched"):
    del mnli_dataset[unused_split]
print('mnli', mnli_dataset["train"][0])
mnli_classes = mnli_dataset["train"].features["label"].names
print(mnli_classes)

# Add "text_label" (label name per example), then drop the "idx" column.
mnli_dataset = mnli_dataset.map(
    lambda batch: {"text_label": [mnli_classes[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)


def preprocess_function_mnli(examples):
    """Tokenize a batch of MNLI premise/hypothesis pairs into inputs and label ids.

    Premise and hypothesis are embedded in a single prompt string. The target
    string is ``str(label + 1)`` for each integer label. Pad-token positions
    in the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    text_pairs = zip(examples['premise'], examples['hypothesis'])
    prompts = [
        f"language inference analysis: mnli premise: {text1} hypothesis: {text2}"
        for text1, text2 in text_pairs
    ]
    answer_strings = [str(class_id + 1) for class_id in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- MNLI: tokenize every split and build the train/eval loaders ----
mnli_processed_datasets = mnli_dataset.map(
    preprocess_function_mnli, batched=True, num_proc=1,
    remove_columns=mnli_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

mnli_train_dataset = mnli_processed_datasets["train"]
# Evaluation uses the "validation_matched" split.
mnli_eval_dataset = mnli_processed_datasets["validation_matched"]

mnli_train_dataloader = DataLoader(
    mnli_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
mnli_eval_dataloader = DataLoader(
    mnli_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)
mnli_metric = evaluate.load('glue', "mnli")

# ---- MNLI synthetic data: same tokenization, single shuffled loader ----
mnli_syn_processed_datasets = mnli_syn.map(
    preprocess_function_mnli, batched=True, num_proc=1,
    remove_columns=mnli_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

mnli_syn_dataset = mnli_syn_processed_datasets["train"]
mnli_syn_dataloader = DataLoader(
    mnli_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)


# ---- QNLI: load GLUE data, drop the test split ----
qnli_dataset = load_dataset("glue", "qnli", cache_dir="./data")
del qnli_dataset["test"]
qnli_classes = qnli_dataset["train"].features["label"].names
print('qnli', qnli_dataset["train"][0])
print(qnli_classes)

# Add "text_label" (label name per example), then drop the "idx" column.
qnli_dataset = qnli_dataset.map(
    lambda batch: {"text_label": [qnli_classes[i] for i in batch["label"]]},
    batched=True,
    num_proc=1,
).map(Remove_Columns, batched=True, num_proc=1)


def preprocess_function_qnli(examples):
    """Tokenize a batch of QNLI question/sentence pairs into inputs and label ids.

    Question and sentence are embedded in a single prompt string. The target
    string is '1' for a truthy ``label`` and '2' otherwise. Pad-token
    positions in the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    text_pairs = zip(examples['question'], examples['sentence'])
    prompts = [
        f"question answering: qnli question: {text1} sentence: {text2}"
        for text1, text2 in text_pairs
    ]
    answer_strings = ['1' if flag else '2' for flag in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- QNLI: tokenize every split and build the train/eval loaders ----
qnli_processed_datasets = qnli_dataset.map(
    preprocess_function_qnli, batched=True, num_proc=1,
    remove_columns=qnli_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

qnli_train_dataset = qnli_processed_datasets["train"]
qnli_eval_dataset = qnli_processed_datasets["validation"]

qnli_train_dataloader = DataLoader(
    qnli_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
qnli_eval_dataloader = DataLoader(
    qnli_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)
qnli_metric = evaluate.load('glue', "qnli")

# ---- QNLI synthetic data: same tokenization, single shuffled loader ----
qnli_syn_processed_datasets = qnli_syn.map(
    preprocess_function_qnli, batched=True, num_proc=1,
    remove_columns=qnli_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

qnli_syn_dataset = qnli_syn_processed_datasets["train"]

qnli_syn_dataloader = DataLoader(
    qnli_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)

# ---- RTE: load GLUE data, drop the test split ----
rte_dataset = load_dataset("glue", "rte", cache_dir="./data")
del rte_dataset["test"]
rte_classes = rte_dataset["train"].features["label"].names
print('rte', rte_dataset["train"][0])
print(rte_classes)

# Add "text_label" (label name per example). The names must come from
# rte_classes; the previous code indexed `classes`, which holds the MRPC
# label names, so every RTE row received an MRPC label string.
rte_dataset = rte_dataset.map(
    lambda x: {"text_label": [rte_classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1,
)

# Drop the "idx" column.
rte_dataset = rte_dataset.map(
    Remove_Columns,
    batched=True,
    num_proc=1,
)


def preprocess_function_rte(examples):
    """Tokenize a batch of RTE sentence pairs into model inputs and label ids.

    Both sentences are embedded in a single prompt string. The target string
    is '1' for a truthy ``label`` and '2' otherwise. Pad-token positions in
    the tokenized targets are overwritten with -100.

    Returns the tokenizer output for the prompts with a ``labels`` entry added.
    """
    sentence_pairs = zip(examples['sentence1'], examples['sentence2'])
    prompts = [
        f"entailment: rte sentence1: {text1} sentence2: {text2}"
        for text1, text2 in sentence_pairs
    ]
    answer_strings = ['1' if flag else '2' for flag in examples["label"]]

    encoded = tokenizer(
        prompts, max_length=max_length, padding="max_length",
        truncation=True, return_tensors="pt")
    label_ids = tokenizer(
        answer_strings, max_length=2, padding="max_length",
        truncation=True, return_tensors="pt")["input_ids"]
    label_ids[label_ids == tokenizer.pad_token_id] = -100

    encoded["labels"] = label_ids
    return encoded


# ---- RTE: tokenize every split and build the train/eval loaders ----
# (Unlike the other tasks, no num_proc argument is passed to map here.)
rte_processed_datasets = rte_dataset.map(
    preprocess_function_rte, batched=True,
    remove_columns=rte_dataset["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

rte_train_dataset = rte_processed_datasets["train"]
rte_eval_dataset = rte_processed_datasets["validation"]

rte_train_dataloader = DataLoader(
    rte_train_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)
rte_eval_dataloader = DataLoader(
    rte_eval_dataset, batch_size=batch_size,
    collate_fn=default_data_collator, pin_memory=True)

# ---- RTE synthetic data: same tokenization, single shuffled loader ----
rte_syn_processed_datasets = rte_syn.map(
    preprocess_function_rte, batched=True,
    remove_columns=rte_syn["train"].column_names,
    load_from_cache_file=False, desc="Running tokenizer on dataset")

rte_syn_dataset = rte_syn_processed_datasets["train"]
rte_syn_dataloader = DataLoader(
    rte_syn_dataset, batch_size=batch_size, shuffle=True,
    collate_fn=default_data_collator, pin_memory=True)


def split_data(data: DataLoader, clients: int = 2) -> "tuple[List[DataLoader], List[int]]":
    """Randomly split the dataset behind ``data`` into ``clients`` shards.

    Shard sizes differ by at most one: the first ``len(dataset) % clients``
    shards receive one extra sample.

    Args:
        data: Loader whose ``dataset`` is split and whose ``batch_size`` is
            reused for every shard loader.
        clients: Number of shards to create; must be at least 1.

    Returns:
        A ``(loaders, lengths)`` tuple: one ``DataLoader`` per shard (built
        without a ``shuffle`` argument) and the matching list of shard sizes.

    Raises:
        ValueError: If ``clients`` is smaller than 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below.
    if clients < 1:
        raise ValueError(f"clients must be a positive integer, got {clients}")

    dataset = data.dataset
    # This helper is called for every task, so the log label is task-neutral.
    print("len dataset", len(dataset))

    split_size, remainder = divmod(len(dataset), clients)
    split_lengths = [
        split_size + 1 if i < remainder else split_size
        for i in range(clients)
    ]

    # random_split is called without an explicit generator argument.
    splits = random_split(dataset, split_lengths)
    split_dataloaders = [
        DataLoader(
            split,
            collate_fn=default_data_collator,
            batch_size=data.batch_size,
            pin_memory=True)
        for split in splits
    ]
    return split_dataloaders, split_lengths


# Number of clients each real GLUE training set is split into
# (original note: "2k data per client").
num_clients_cola, num_clients_sst, num_clients_mrpc = 4, 30, 2
num_clients_qqp, num_clients_mnli, num_clients_qnli = 180, 200, 50
num_clients_rte = 1

# Number of shards each synthetic dataset is split into.
num_syn_cola = num_syn_sst = num_syn_mrpc = num_syn_qqp = 6
num_syn_mnli = num_syn_qnli = num_syn_rte = 6

# Totals, kept as literals: 7 * 6 = 42 synthetic shards and
# 4 + 30 + 2 + 180 + 200 + 50 + 1 = 467 real clients.
number_of_syns = 42
number_of_clients = 467
num_selected_clients = 163 + 30
# ---- Split every real training set into per-client loaders ----
cola_client_dataloaders, len_data_cola = split_data(cola_train_dataloader, clients=num_clients_cola)
print("len client loader:", len(cola_client_dataloaders), len(cola_client_dataloaders[0]))

sst_client_dataloaders, len_data_sst = split_data(sst_train_dataloader, clients=num_clients_sst)
mrpc_client_dataloaders, len_data_mrpc = split_data(mrpc_train_dataloader, clients=num_clients_mrpc)
qqp_client_dataloaders, len_data_qqp = split_data(qqp_train_dataloader, clients=num_clients_qqp)
mnli_client_dataloaders, len_data_mnli = split_data(mnli_train_dataloader, clients=num_clients_mnli)
qnli_client_dataloaders, len_data_qnli = split_data(qnli_train_dataloader, clients=num_clients_qnli)
rte_client_dataloaders, len_data_rte = split_data(rte_train_dataloader, clients=num_clients_rte)

# ---- Split every synthetic set into shard loaders ----
cola_syn_client_dataloaders, len_syn_cola = split_data(cola_syn_dataloader, clients=num_syn_cola)
sst_syn_client_dataloaders, len_syn_sst = split_data(sst_syn_dataloader, clients=num_syn_sst)
mrpc_syn_client_dataloaders, len_syn_mrpc = split_data(mrpc_syn_dataloader, clients=num_syn_mrpc)
qqp_syn_client_dataloaders, len_syn_qqp = split_data(qqp_syn_dataloader, clients=num_syn_qqp)
mnli_syn_client_dataloaders, len_syn_mnli = split_data(mnli_syn_dataloader, clients=num_syn_mnli)
qnli_syn_client_dataloaders, len_syn_qnli = split_data(qnli_syn_dataloader, clients=num_syn_qnli)
rte_syn_client_dataloaders, len_syn_rte = split_data(rte_syn_dataloader, clients=num_syn_rte)

# Shard sizes: synthetic shards only, then real clients followed by the
# synthetic shards (same task order in both parts).
len_data_syn = (
    len_syn_cola + len_syn_sst + len_syn_mrpc + len_syn_qqp
    + len_syn_mnli + len_syn_qnli + len_syn_rte
)
len_data_clients = (
    len_data_cola + len_data_sst + len_data_mrpc + len_data_qqp
    + len_data_mnli + len_data_qnli + len_data_rte
    + len_data_syn
)
print('len data clients', len(len_data_clients))

# Loaders of the synthetic shards only.
proxy_dataloaders = (
    cola_syn_client_dataloaders + sst_syn_client_dataloaders
    + mrpc_syn_client_dataloaders + qqp_syn_client_dataloaders
    + mnli_syn_client_dataloaders + qnli_syn_client_dataloaders
    + rte_syn_client_dataloaders
)

# All loaders: real clients first, synthetic shards appended after them.
# Original author note: for a real proxy the syn loaders are not added here;
# they are added for the synthetic setup (the code separates them).
client_dataloaders = (
    cola_client_dataloaders + sst_client_dataloaders
    + mrpc_client_dataloaders + qqp_client_dataloaders
    + mnli_client_dataloaders + qnli_client_dataloaders
    + rte_client_dataloaders
    + proxy_dataloaders
)

"""Training"""
peft_config = PrefixTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    num_virtual_tokens=num_virtual_tokens,
)

# Per-task result lists (each name is bound to its own empty list).
global_acc_cola, global_acc_sst, global_acc_mrpc = [], [], []
global_acc_stsb, global_acc_qqp, global_acc_mnli = [], [], []
global_acc_qnli, global_acc_rte, global_acc_wnli = [], [], []

number_cola_rounds, number_sst_rounds, number_mrpc_rounds = [], [], []
number_stsb_rounds, number_qqp_rounds, number_mnli_rounds = [], [], []
number_qnli_rounds, number_rte_rounds, number_wnli_rounds = [], [], []

# Index ranges into client_dataloaders: 0..466 are the real clients,
# 467..508 are the synthetic shards.
total_regular_indices = list(range(467))
total_proxy_indices = list(range(467, 509))

print('loader l', len(client_dataloaders))
print(len(total_regular_indices))
print(len(total_proxy_indices))

miss_rounds = range(25, 35)

total_cl_indices = list(range(509))
client_half_state = [0 for _ in total_cl_indices]

loss_list = []

dataset_to_indices = {}
#avg window method
def detect_anomalies(data, ref_window=2, drop_ratio=0.5, jump_ratio=1.5): #use wg var
    """Return the indices of `data` whose value drops sharply below a reference mean.

    The first `ref_window` points are never flagged; they only seed the first
    reference. For every later index ``i`` the value is compared against
    ``ref_mean * drop_ratio``:

    * Outside an anomalous run, ``ref_mean`` is the mean of the `ref_window`
      values immediately preceding ``i``.
    * Inside an anomalous run (the previous point was flagged), ``ref_mean`` is
      a frozen reference recorded at the last non-anomalous point, so that a
      run of low values does not drag the reference down with it.

    Args:
        data: Sequence of numbers supporting ``len()``, indexing and slicing.
        ref_window: Number of preceding points averaged into the reference.
        drop_ratio: A point is anomalous when ``data[i] < ref_mean * drop_ratio``.
        jump_ratio: Accepted for interface compatibility; not used by the
            current implementation.

    Returns:
        List of integer indices (ascending) flagged as anomalous.
    """
    anomalies = []
    normal_ref_mean = None
    anomaly_flag = False

    for i in range(ref_window, len(data)):
        # Mean of the window that ends just before the current point.
        window_mean = sum(data[i - ref_window:i]) / ref_window
        ref_mean = normal_ref_mean if anomaly_flag else window_mean

        if data[i] < ref_mean * drop_ratio:
            anomalies.append(i)
            # If the very first examined point is anomalous there is no frozen
            # reference yet; seed it with the window mean that triggered the
            # detection. Without this the next iteration would multiply None
            # by `drop_ratio` and raise TypeError.
            if normal_ref_mean is None:
                normal_ref_mean = ref_mean
            anomaly_flag = True
        else:
            # Normal point: refresh the frozen reference and leave the run.
            normal_ref_mean = window_mean
            anomaly_flag = False
    return anomalies
      
for Round in range(1, 50 + 1):
    c_cl = {c: 0 for c in range(number_of_clients)}
    print(f"\n**************Running round {Round}***********************\n")

    round_dir = f'logs/FedMemo/round{Round}/'

    model_dir = round_dir + 'model_parameters'
    os.makedirs(model_dir, exist_ok=True)
    client_grad = {}
    c_weights = {}
    
    cl_acc_dict = {
        'cola': {'0': [], '1': []}, 
        'sst': {'0': [], '1': []},  
        'mnli': {'0': [], '1': [], '2': []}, 
        'qnli': {'0': [], '1': []},  
        'rte': {'0': [], '1': []},  
        'wnli': {'0': [], '1': []}, 
        'mrpc': {'0': [], '1': []}, 
    }  

    selected_clients_reg = [client_dataloaders[index] for index in sorted(random.sample(total_regular_indices, k=163))] 
    proxy_indices = sorted(random.sample(total_proxy_indices, k=30))
    selected_proxy = [client_dataloaders[index] for index in proxy_indices]
    selected_clients = selected_clients_reg + selected_proxy
   
    reg_indices = [index for index, client_loader in enumerate(client_dataloaders) if client_loader in selected_clients_reg]
    proxy_indices = [index for index, client_loader in enumerate(client_dataloaders) if client_loader in selected_proxy]
    print("sel reg:", reg_indices, "sel proxy:", proxy_indices)
    selected_indices = reg_indices + proxy_indices
    sel_2step_indices = [reg_indices, proxy_indices]
    sel_2step_indices[0].sort()
    sel_2step_indices[1].sort()
    
    selected_indices.sort()
    print("sel all:", selected_indices)
    note_clients_ids = {'cola':[], 'sst':[], 'stsb':[], 'qnli':[], 'wnli':[], 'rte':[] }
    
    print('sel indices', sel_2step_indices)
    for f, sel_indices in enumerate(sel_2step_indices):

        if f == 0:
            flag_reg = 1
            flag_proxy = 0
            flag_step2 = 0
        else:
            flag_proxy = 1
            flag_step2 = 1
            flag_reg = 0
        if Round not in miss_rounds and flag_proxy == 1:
            continue        
        for index in sel_indices:
            client_id = index
            client_dataloader = client_dataloaders[index]
            total_batches = len(client_dataloader)

            print(len(client_dataloader))
            print(f"\n***Training on Client {client_id }****\n")

            model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)  # default before reducing attention heads
            model = get_peft_model(model, peft_config)

            round_model_copy = None
            if Round == 1:

                initial_state_dict = model.state_dict()
                torch.save(initial_state_dict, "./init_model_large.pt")
                init_model_dict = torch.load(
                    r"./init_model_large.pt")
                
                for name, param in model.named_parameters():
                    if name in init_model_dict:
                        param.data = init_model_dict[name].clone().detach()
                round_model_copy = copy.deepcopy(model).to(device)
                
            if Round > 1:
                for name, param in model.named_parameters():
                    if name in prev_global_model_dict:
                        param.data = prev_global_model_dict[name].clone().detach()
                
            if Round == last_round and Round > 1:
                global_model_copy = copy.deepcopy(model)
            
                global_model_copy.to(device)

            model.print_trainable_parameters()

            client_using_dataset = None
            if len(client_dataloader) == len(cola_client_dataloaders[0]) or len(client_dataloader) == len(cola_syn_client_dataloaders[0]):
                client_using_dataset = 'cola'
            if len(client_dataloader) == len(sst_client_dataloaders[0]) or len(client_dataloader) == len(sst_syn_client_dataloaders[0]):
                client_using_dataset = 'sst'
            if len(client_dataloader) == len(mrpc_client_dataloaders[0]) or len(client_dataloader) == len(mrpc_syn_client_dataloaders[0]):
                client_using_dataset = 'mrpc'
            if len(client_dataloader) == len(qqp_client_dataloaders[0]) or len(client_dataloader) == len(qqp_syn_client_dataloaders[0]):
                client_using_dataset = 'qqp'
            if len(client_dataloader) == len(mnli_client_dataloaders[0]) or len(client_dataloader) == len(mnli_syn_client_dataloaders[0]):
                client_using_dataset = 'mnli'
            if len(client_dataloader) == len(qnli_client_dataloaders[0]) or len(client_dataloader) == len(qnli_syn_client_dataloaders[0]):
                client_using_dataset = 'qnli'
            if len(client_dataloader) == len(rte_client_dataloaders[0]) or len(client_dataloader) == len(rte_syn_client_dataloaders[0]):
                client_using_dataset = 'rte'
            print('client data', client_using_dataset)

            optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
            lr_scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=0,
                num_training_steps=(len(client_dataloader) * local_epochs),
            )

            # Train the model on the current client's dataset
            model = model.to(device)

            grads = None
            for epoch in range(local_epochs):
                model.train()
                total_loss = 0
                for step, batch in enumerate(tqdm(client_dataloader)):

                    if Round in miss_rounds and client_id in reg_indices: # apply desired missing data
                        batch_x = []
                        batch_att = []
                        batch_y = []
                        c = 0
                        for j in range(len(batch['labels'])):
                            if batch['labels'][j].tolist() == [204, 1] or batch['labels'][j].tolist() == [220, 1]:
                                c += 1
                                c_cl[client_id] += 1    # number of remained data 
                                batch_x.append(batch['input_ids'][j])
                                batch_att.append(batch['attention_mask'][j])
                                batch_y.append(batch['labels'][j])
                       
                        if c > 0:
                            b_x = torch.stack(batch_x)
                            b_y = torch.stack(batch_y)
                            b_att = torch.stack(batch_att)
                            b_x, b_att, b_y = b_x.to(device), b_att.to(device), b_y.to(device)
                        else:
                            continue
                            
                    else:
                        b_x = batch['input_ids'].to(device)
                        b_att = batch['attention_mask'].to(device)
                        b_y = batch['labels'].to(device)
                    
                    outputs = model(input_ids=b_x, attention_mask=b_att, labels=b_y)
                    
                    if Round == 1:
                        proximal_term = 0.0
                        for w, w_t in zip(model.parameters(),
                                          round_model_copy.parameters()):
                            proximal_term += (w - w_t).norm(2)
                        loss = outputs.loss + (mu / 2) * proximal_term
                        # loss = outputs.loss
                        total_loss += loss.detach().float()
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()
                        for param in model.parameters():
                            if param.requires_grad:
                                if grads is None:
                                    grads = param.grad.clone().detach()
                                else:
                                    grads += param.grad.clone().detach()
                        optimizer.zero_grad()
                    else:
                        proximal_term = 0.0
                        for w, w_t in zip(model.parameters(),
                                          global_model_copy.parameters()):
                            proximal_term += (w - w_t).norm(2)
                        loss = outputs.loss + (mu / 2) * proximal_term
                        total_loss += loss.detach().float()
                        loss.backward()
                        optimizer.step()
                        lr_scheduler.step()
                        for param in model.parameters():
                            if param.requires_grad:
                                if grads is None:
                                    grads = param.grad.clone().detach()
                                else:
                                    grads += param.grad.clone().detach()
                                            
                        optimizer.zero_grad()

                       
                
            client_param_values = {}
            for name, param in model.named_parameters():
                if param.requires_grad:
                    print(f"Parameter Name: {name}, Shape: {param.shape}")
                    client_param_values[name] = param.data
                    
            c_weights[client_id] = [param for param in model.parameters() if param.requires_grad]
            
            client_grad[client_id] = grads#[param.grad for param in model.parameters() if param.requires_grad]
        
            torch.save(
                client_param_values,
                f'{model_dir}/client_{client_id}_model_{client_using_dataset}.pt')

            # delete the model and clear the cache
            del model
            torch.cuda.empty_cache()

            is_cache_cleared = torch.cuda.memory_allocated() == 0
            is_memory_cleared = torch.cuda.memory_summary(device=None).strip()
            print(f"[Client {client_id}] Allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
            print(f"[Client {client_id}] Reserved: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")
        
        global_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
        global_model = get_peft_model(global_model, peft_config)
        global_model.to(device)

        # Initialize a variable to store the accumulated parameters
        global_params = {}
        weights_proxy = {}
        proxy_grad = {}
        glob_grad = {}
            
        # Load the client models' parameters and accumulate them

        filelist = os.listdir(model_dir)
        filelist.sort(key=lambda x: int(x[7:7 + x[7:].find('_')]) if x[7:7 + x[7:].find('_')].isdigit() else float('inf')
)
        #print(filelist)
        print('------------------------------------')
        total_num_samples, cola_num, sst_num, mrpc_num, stsb_num, qqp_num, mnli_num, qnli_num, rte_num, wnli_num = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        for param_id, params in enumerate(filelist):
            if params[-7:-3] == r'cola':
                cola_num += 1
            if params[-6:-3] == r'sst':
                sst_num += 1
            if params[-7:-3] == r'mrpc':
                mrpc_num += 1
            if params[-6:-3] == r'qqp':
                qqp_num += 1
            if params[-7:-3] == r'mnli':
                mnli_num += 1
            if params[-7:-3] == r'qnli':
                qnli_num += 1
            if params[-6:-3] == r'rte':
                rte_num += 1

        print(
            f'number of cola: {cola_num} number of sst:{sst_num} number of mrpc:{mrpc_num} number of qqp:{qqp_num} number of mnli:{mnli_num} number of qnli:{qnli_num} number of rte:{rte_num}')
        
        number_cola_rounds.append(cola_num)
        number_sst_rounds.append(sst_num)
        number_mrpc_rounds.append(mrpc_num)
        number_qqp_rounds.append(qqp_num)
        number_mnli_rounds.append(mnli_num)
        number_qnli_rounds.append(qnli_num)
        number_rte_rounds.append(rte_num)
        for t,c in enumerate(c_cl):
            if c_cl[t] != 0:
                c_cl[t] = len_data_clients[t] - c_cl[t]   
        if flag_reg == 1:        
            for id in reg_indices:
                total_num_samples += (len_data_clients[id] - c_cl[id]) 
        # dont do this for 2 step aggregation
        if Round in miss_rounds and flag_proxy == 1 :
            for id in proxy_indices:
                total_num_samples += (len_data_clients[id]) #cola_num * len(cola_client_dataloaders[0])   +\
        

        total_num_proxy = 0
        for id in proxy_indices:
            total_num_proxy += (len_data_clients[id]) #cola_num * len(cola_client_dataloaders[0]) 

        #aggregate clients/proxy (at step2)
        print(f"aggregate {number_of_clients} clients")
        aggregate_times = 0
        fake_id = 0
        file_dict = {}
        for file in filelist:
            parts = file.split('_')
            if 'global' in file or 'glue' in file:
                continue
            client_id = int(parts[1])  # Extract the numerical ID
            print(client_id)
            file_dict[client_id] = file

        for client_id in sel_indices: #number_of_clients
            print("c id:", client_id)
            
            
            if client_id in sel_indices: 
                
                if client_id in file_dict:
                    filename = file_dict[client_id]
                    client_dataset = filename.split('_')[-1]
                    print('data', client_dataset)
                    if client_dataset.endswith('.pt'):
                        client_dataset = client_dataset[:-3]
                client_model_path = os.path.join(
                    model_dir, f'client_{client_id}_model_{client_dataset}.pt')
                print(client_model_path)
                fake_id += 1
                client_model_dict = torch.load(client_model_path)
                if client_id in reg_indices:
                    weight = (len_data_clients[client_id] - c_cl[client_id] )/ total_num_samples #(len(cola_client_dataloaders[0]) - c_cl[client_id] )/ total_num_samples
                    wp = weight

                    print(
                        f'this client is using {client_dataset} dataset, the weight of this client is {weight}')
                if client_id in proxy_indices:
                    weight = (len_data_clients[client_id] )/ total_num_proxy
                if Round in rand and client_id in proxy_indices:
                    weight_help = (len_data_clients[client_id] )/ total_num_samples # in 2 steps we dont total_num_samples

                # Accumulate the client model parameters
                if client_id in reg_indices:
                    for param_name, param in client_model_dict.items():
                        if param_name not in global_params:
                            global_params[param_name] = param.clone().detach() * weight
                        else:
                            global_params[param_name] += param.clone().detach() * weight
                    aggregate_times += 1
                if Round in miss_rounds and client_id in proxy_indices:  
                    for param_name, param in client_model_dict.items():
                        if param_name not in global_params:
                            global_params[param_name] = param.clone().detach() * weight_help
                        else:
                            global_params[param_name] += param.clone().detach() * weight_help


                    aggregate_times += 1

                if client_id in proxy_indices:                
                    for param_name, param in client_model_dict.items():
                        if param_name not in weights_proxy:
                            weights_proxy[param_name] = param.clone().detach() * weight
                        else:
                            weights_proxy[param_name] += param.clone().detach() * weight


        print(f'aggregated {aggregate_times} ')
        print('\n')
        

        for name, param in global_model.named_parameters():
            if name in global_params:
                param.data = global_params[name].clone().detach()
        # global_model.load_state_dict(global_params)
        
        ##################################
        flattened_glob = global_params['prompt_encoder.default.embedding.weight'].reshape(-1)
        
        flattened_glob_cpu = flattened_glob.cpu().numpy()
        #flattened_proxy_cpu = flattened_proxy.cpu().numpy()

        gl_w = flattened_glob_cpu
        #pr_w = flattened_proxy_cpu
        if Round == 2:
            w_grad0 = gl_w - gl_w_prev  
            #w_grad_p0 = pr_w - pr_w_prev
        if Round > 2:
            w_grad1 = gl_w - gl_w_prev
            #w_grad_p1 = pr_w - pr_w_prev
            w_grad0 = w_grad1
            #w_grad_p0 = w_grad_p1
        if Round > 1:
            #grad_cos = 1-spatial.distance.cosine(w_grad0, w_grad_p0 )
            grad_var = np.var(w_grad0)  ##weight update variance
            #grad_varp = np.var(w_grad_p0)
            wg_var.append(grad_var)
            #wg_varp.append(grad_varp)
            #wg_cos.append(grad_cos)

        gl_w_prev = gl_w
        #pr_w_prev = pr_w
        

        #print("wg_var = ", wg_var)
        
        print("-----------------------------------")
        print(" cola_r=", number_cola_rounds)
        print(" sst_r=", number_sst_rounds)
        print(" mrpc_r=", number_mrpc_rounds)
        print(" qqp_r=", number_qqp_rounds)
        print(" mnli_r=", number_mnli_rounds)
        print(" qnli_r=", number_qnli_rounds)
        print(" rte_r=", number_rte_rounds)
        #################################
        global_model_copy = copy.deepcopy(global_model)
        #if flag_step2 == 0:
        global_model_path = os.path.join(
            model_dir, f'global_model_{number_of_clients}_clients.pt')
        torch.save(global_params, global_model_path)
        prev_global_model_path = global_model_path
        if flag_step2 == 1 or Round not in rand:
            # Prediction in cola
            global_model.eval()
            cola_eval_preds = []

            for step, batch in enumerate(tqdm(cola_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                cola_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            # Prediction in sst2
            global_model.eval()
            sst_eval_preds = []

            for step, batch in enumerate(tqdm(sst_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                sst_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            # Prediction in mrpc
            global_model.eval()
            mrpc_eval_preds = []

            for step, batch in enumerate(tqdm(mrpc_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                mrpc_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            

            # Prediction in qqp
            global_model.eval()
            qqp_eval_preds = []

            for step, batch in enumerate(tqdm(qqp_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                qqp_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            # Prediction in mnli
            global_model.eval()
            mnli_eval_preds = []

            for step, batch in enumerate(tqdm(mnli_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                mnli_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            # Prediction in qnli
            global_model.eval()
            qnli_eval_preds = []

            for step, batch in enumerate(tqdm(qnli_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                qnli_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            # Prediction in rte
            global_model.eval()
            rte_eval_preds = []

            for step, batch in enumerate(tqdm(rte_eval_dataloader)):
                batch = {k: v.to(device) for k, v in batch.items()}
                with torch.no_grad():
                    outputs = global_model(**batch)
                rte_eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                                                            skip_special_tokens=True))

            
            # Evaluation on cola validation set
            class_correct = {0: 0, 1: 0}  # Correct predictions per class
            class_total = {0: 0, 1: 0}    # Total samples per class

            cola_correct = 0
            cola_total = 0

            for pred, true in zip(cola_eval_preds, cola_dataset["validation"]["label"]):
                cola_total += 1
                if pred.strip() == "":
                    continue  # Skip empty predictions
                if len(pred) != 1:
                    continue
                if pred not in ['1', '2']:
                    continue

                # Convert predictions to class indices (1 -> 1, 2 -> 0)
                pred_class = 0 if pred == '2' else 1
                true_class = true

                # Update total counts for true class
                class_total[true_class] += 1

                # Check for correctness and update counts
                if int(pred) == int(true):
                    cola_correct += 1
                    class_correct[true_class] += 1
                elif int(pred) == int(true) + 2:
                    cola_correct += 1
                    class_correct[true_class] += 1


            # Calculate global accuracy
            cola_global_accuracy = cola_correct / cola_total * 100
            print(
                f"\nRound: {Round} Global Accuracy for Model on COLA dataset on aggregating {number_of_clients} clients: {cola_global_accuracy}%")
            global_acc_cola.append(cola_global_accuracy)
            # Calculate per-class accuracies
            class_accuracies = {cls: (class_correct[cls] / class_total[cls] * 100 if class_total[cls] > 0 else 0)
                                for cls in class_total}

            ds_name = 'cola'  # Dataset name

            # Record the per-class CoLA accuracies for this round.
            print("\nClass-wise Accuracies:")
            for cls, acc in class_accuracies.items():
                print('cls:', cls, type(cls))
                print(f"Class {cls}: {acc:.2f}%")
                # cl_acc_dict is keyed by the class label as a string ('0', '1'),
                # while `cls` is an int, so the membership test must use
                # str(cls); testing the int directly never matches and the
                # accuracy would never be stored.
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")
            # Evaluation on sst validation set
            # Initialize counters for each class
            class_correct_sst = {0: 0, 1: 0}  # Correct predictions per class
            class_total_sst = {0: 0, 1: 0}    # Total samples per class

            sst_correct = 0
            sst_total = 0

            for pred, true in zip(sst_eval_preds, sst_dataset["validation"]["label"]):
                sst_total += 1
                if pred not in ['1', '2']:
                    continue

                # Convert predictions to class indices (1 -> 1, 2 -> 0)
                pred_class = 0 if pred == '2' else 1
                true_class = true

                # Update total counts for true class
                class_total_sst[true_class] += 1

                # Check for correctness and update counts
                if int(pred.strip()) == int(true): 
                    sst_correct += 1
                    class_correct_sst[true_class] += 1
                elif int(pred.strip()) == int(true) + 2:
                    sst_correct += 1
                    class_correct_sst[true_class] += 1

            # Calculate global accuracy
            sst_global_accuracy = sst_correct / sst_total * 100
            global_acc_sst.append(sst_global_accuracy)
            print(
                f"\nRound: {Round} Global Accuracy for Model on SST dataset on aggregating {number_of_clients} clients: {sst_global_accuracy}%"
            )

            # Calculate per-class accuracies
            class_accuracies_sst = {
                cls: (class_correct_sst[cls] / class_total_sst[cls] * 100 if class_total_sst[cls] > 0 else 0)
                for cls in class_total_sst
            }

            ds_name = 'sst'  # Replace this with the actual dataset name being processed

            # Iterate over the computed class accuracies
            print("\nClass-wise Accuracies for SST:")
            for cls, acc in class_accuracies_sst.items():
                print(f"Class {cls}: {acc:.2f}%")
                # Append the accuracy to the corresponding class in the dictionary
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")


            # Evaluation on mrpc validation set
            #print(f"\nmrpc dataset, global predictions: {mrpc_eval_preds}")
            #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
            # Initialize counters for each class
            class_correct_mrpc = {0: 0, 1: 0}  # Correct predictions per class
            class_total_mrpc = {0: 0, 1: 0}    # Total samples per class

            mrpc_correct = 0
            mrpc_total = 0

            for pred, true in zip(mrpc_eval_preds, mrpc_dataset["validation"]["label"]):
                mrpc_total += 1
                if pred not in ['1', '2']:
                    continue

                # Convert predictions to class indices 
                pred_class = int(pred)
                true_class = int(true)

                # Update total counts for true class
                class_total_mrpc[true_class] += 1

                # Check for correctness and update counts
                if pred_class == true_class:
                    mrpc_correct += 1
                    class_correct_mrpc[true_class] += 1
                elif int(pred) == int(true) + 2:
                    mrpc_correct += 1
                    class_correct_mrpc[true_class] += 1

            # Calculate global accuracy
            mrpc_global_accuracy = mrpc_correct / mrpc_total * 100
            print(
                f"\nRound: {Round} Global Accuracy for Model on MRPC dataset on aggregating {number_of_clients} clients: {mrpc_global_accuracy}%"
            )

            # Calculate per-class accuracies
            class_accuracies_mrpc = {
                cls: (class_correct_mrpc[cls] / class_total_mrpc[cls] * 100 if class_total_mrpc[cls] > 0 else 0)
                for cls in class_total_mrpc
            }

            ds_name = 'mrpc'  # Dataset name

            # Iterate over the computed class accuracies
            print("\nClass-wise Accuracies for MRPC:")
            for cls, acc in class_accuracies_mrpc.items():
                print(f"Class {cls}: {acc:.2f}%")
                # Append the accuracy to the corresponding class in the dictionary
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")

            # Convert predictions for metric computation
            """mrpc_predictions = []
            for prediction in mrpc_eval_preds:
                if prediction == '2':
                    mrpc_predictions.append(2)  # Class 2 (not equivalent)
                if prediction == '1':
                    mrpc_predictions.append(1)  # Class 1 (equivalent)

            mrpc_references = mrpc_dataset["validation"]["label"]

            # Compute additional metrics if predictions and references match
            if len(mrpc_references) == len(mrpc_predictions):
                mrpc_results = mrpc_metric.compute(
                    predictions=mrpc_predictions,
                    references=mrpc_references
                )
                print(
                    f"\nRound: {Round} Global Results for Model on MRPC dataset on aggregating {number_of_clients} clients: {mrpc_results}"
                )
                global_acc_mrpc.append([mrpc_global_accuracy, mrpc_results])
            else:
                print(
                    f"\nRound: {Round} There is no result for MRPC dataset"
                )"""
            global_acc_mrpc.append(mrpc_global_accuracy)


            # Evaluation on stsb validation set
            print(f"\nstsb dataset, global predictions: {stsb_eval_preds}")

            

            # Evaluation on qqp validation set
            #
            # Predictions are the strings '1' / '2'; references are the integer
            # labels 0 / 1 (the keys of the per-class counters below).  A prediction
            # is correct when it equals the label or the label + 2, i.e. '1'
            # matches label 1 and '2' matches label 0.
            class_correct_qqp = {0: 0, 1: 0}  # Correct predictions per true class
            class_total_qqp = {0: 0, 1: 0}    # Well-formed predictions per true class

            qqp_correct = 0
            qqp_total = 0

            print(f"\nQQP dataset, global predictions: {qqp_eval_preds}")

            # Read the label column once and reuse it for the loop and as the
            # reference list, instead of materialising it twice.
            qqp_references = qqp_dataset["validation"]["label"]

            for pred, true in zip(qqp_eval_preds, qqp_references):
                # Every sample counts towards the global total, so empty or
                # malformed predictions lower the global accuracy ...
                qqp_total += 1
                # ... but they are excluded from the per-class statistics.
                # (Covers the empty / wrong-length / unknown-token cases at once.)
                if pred not in ('1', '2'):
                    continue

                pred_class = int(pred)
                true_class = int(true)

                # Update total counts for true class
                class_total_qqp[true_class] += 1

                # '1' <-> label 1 (equal) and '2' <-> label 0 (label + 2)
                if pred_class == true_class or pred_class == true_class + 2:
                    qqp_correct += 1
                    class_correct_qqp[true_class] += 1

            # Global accuracy in percent.  Guarded like the per-class accuracies
            # below so an empty prediction list reports 0 instead of raising
            # ZeroDivisionError.
            qqp_global_accuracy = (
                qqp_correct / qqp_total * 100 if qqp_total > 0 else 0.0
            )
            print(
                f"\nRound: {Round} Global Accuracy for Model on QQP dataset on aggregating {number_of_clients} clients: {qqp_global_accuracy}%"
            )

            # Per-class accuracy in percent (0 when a class received no
            # well-formed prediction)
            class_accuracies_qqp = {
                cls: (class_correct_qqp[cls] / class_total_qqp[cls] * 100 if class_total_qqp[cls] > 0 else 0)
                for cls in class_total_qqp
            }

            ds_name = 'qqp'  # Dataset name

            # Report the class accuracies and append them to the shared history
            print("\nClass-wise Accuracies for QQP:")
            for cls, acc in class_accuracies_qqp.items():
                print(f"Class {cls}: {acc:.2f}%")
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")

            # Integer form of the well-formed predictions (malformed ones are
            # dropped, so this list can be shorter than qqp_references).
            qqp_predictions = [int(p) for p in qqp_eval_preds if p in ('1', '2')]

            # The qqp_metric.compute(predictions=qqp_predictions,
            # references=qqp_references) step is currently disabled; only the
            # plain accuracy is recorded.
            global_acc_qqp.append(qqp_global_accuracy)


            # Evaluation on mnli validation set
            #
            # References are the integer labels 0 / 1 / 2 (the keys of the
            # per-class counters); a prediction is correct when it equals
            # label + 1, i.e. '1' / '2' / '3'.
            class_correct_mnli = {0: 0, 1: 0, 2: 0}  # Correct predictions per true class
            class_total_mnli = {0: 0, 1: 0, 2: 0}    # Well-formed predictions per true class

            correct_mnli = 0
            total_mnli = 0

            print(f"\nMNLI dataset, global predictions: {mnli_eval_preds}")

            for pred, true in zip(mnli_eval_preds, mnli_dataset["validation_matched"]["label"]):
                # Every sample counts towards the global total.
                total_mnli += 1

                # Skip empty / invalid predictions.  '0', '4' and '5' are accepted
                # as well-formed although they can never equal label + 1, so they
                # still count against the per-class accuracy.
                if pred not in ('0', '1', '2', '3', '4', '5'):
                    continue

                # Convert predictions and true labels to integers
                pred_class = int(pred)
                true_class = int(true)
                class_total_mnli[true_class] += 1

                if pred_class == true_class + 1:
                    correct_mnli += 1
                    class_correct_mnli[true_class] += 1

            # Global accuracy in percent.  Guarded like the per-class accuracies
            # below so an empty prediction list reports 0 instead of raising
            # ZeroDivisionError.
            mnli_global_accuracy = (
                correct_mnli / total_mnli * 100 if total_mnli > 0 else 0.0
            )
            global_acc_mnli.append(mnli_global_accuracy)
            print(
                f"\nRound: {Round} Global Accuracy for Model on MNLI dataset on aggregating {number_of_clients} clients: {mnli_global_accuracy}%"
            )

            # Per-class accuracy in percent (0 when a class received no
            # well-formed prediction)
            class_accuracies_mnli = {
                cls: (class_correct_mnli[cls] / class_total_mnli[cls] * 100 if class_total_mnli[cls] > 0 else 0)
                for cls in class_total_mnli
            }
            ds_name = 'mnli'
            print("\nClass-wise Accuracies for MNLI:")
            for cls, acc in class_accuracies_mnli.items():
                print(f"Class {cls}: {acc:.2f}%")
                # Append the accuracy to the corresponding class in the dictionary
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")

            # Evaluation on qnli validation set
            #
            # Predictions are the strings '1' / '2'; references are the integer
            # labels 0 / 1 (the keys of the per-class counters below).  A prediction
            # is correct when it equals the label or the label + 2, i.e. '1'
            # matches label 1 and '2' matches label 0.
            class_correct_qnli = {0: 0, 1: 0}  # Correct predictions per true class
            class_total_qnli = {0: 0, 1: 0}    # Well-formed predictions per true class

            qnli_correct = 0
            qnli_total = 0

            print(f"\nQNLI dataset, global predictions: {qnli_eval_preds}")

            # Read the label column once and reuse it for the loop and as the
            # reference list, instead of materialising it twice.
            qnli_references = qnli_dataset["validation"]["label"]

            for pred, true in zip(qnli_eval_preds, qnli_references):
                # Every sample counts towards the global total, so empty or
                # malformed predictions lower the global accuracy ...
                qnli_total += 1
                # ... but they are excluded from the per-class statistics.
                # (Covers the empty / wrong-length / unknown-token cases at once.)
                if pred not in ('1', '2'):
                    continue

                pred_class = int(pred)
                true_class = int(true)

                # Update total counts for true class
                class_total_qnli[true_class] += 1

                # '1' <-> label 1 (equal) and '2' <-> label 0 (label + 2)
                if pred_class == true_class or pred_class == true_class + 2:
                    qnli_correct += 1
                    class_correct_qnli[true_class] += 1

            # Global accuracy in percent.  Guarded like the per-class accuracies
            # below so an empty prediction list reports 0 instead of raising
            # ZeroDivisionError.
            qnli_global_accuracy = (
                qnli_correct / qnli_total * 100 if qnli_total > 0 else 0.0
            )
            print(
                f"\nRound: {Round} Global Accuracy for Model on QNLI dataset on aggregating {number_of_clients} clients: {qnli_global_accuracy}%"
            )

            # Per-class accuracy in percent (0 when a class received no
            # well-formed prediction)
            class_accuracies_qnli = {
                cls: (class_correct_qnli[cls] / class_total_qnli[cls] * 100 if class_total_qnli[cls] > 0 else 0)
                for cls in class_total_qnli
            }
            ds_name = 'qnli'
            # Report the class accuracies and append them to the shared history
            print("\nClass-wise Accuracies for QNLI:")
            for cls, acc in class_accuracies_qnli.items():
                print(f"Class {cls}: {acc:.2f}%")
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")

            # Integer form of the well-formed predictions (malformed ones are
            # dropped, so this list can be shorter than qnli_references).
            qnli_predictions = [int(p) for p in qnli_eval_preds if p in ('1', '2')]

            # The qnli_metric.compute(predictions=qnli_predictions,
            # references=qnli_references) step is currently disabled; only the
            # plain accuracy is recorded.
            global_acc_qnli.append(qnli_global_accuracy)


            # Evaluation on rte validation set
            #
            # Predictions are the strings '1' / '2'; references are the integer
            # labels 0 / 1 (the keys of the per-class counters below).  A prediction
            # is correct when it equals the label or the label + 2, i.e. '1'
            # matches label 1 and '2' matches label 0.
            class_correct_rte = {0: 0, 1: 0}  # Correct predictions per true class
            class_total_rte = {0: 0, 1: 0}    # Well-formed predictions per true class

            rte_correct = 0
            rte_total = 0

            print(f"\nRTE dataset, global predictions: {rte_eval_preds}")

            for pred, true in zip(rte_eval_preds, rte_dataset["validation"]["label"]):
                # Every sample counts towards the global total, so empty or
                # malformed predictions lower the global accuracy ...
                rte_total += 1
                # ... but they are excluded from the per-class statistics.
                # (Covers the empty / wrong-length / unknown-token cases at once.)
                if pred not in ('1', '2'):
                    continue

                pred_class = int(pred)
                true_class = int(true)

                # Update total counts for true class
                class_total_rte[true_class] += 1

                # '1' <-> label 1 (equal) and '2' <-> label 0 (label + 2)
                if pred_class == true_class or pred_class == true_class + 2:
                    rte_correct += 1
                    class_correct_rte[true_class] += 1

            # Global accuracy in percent.  Guarded like the per-class accuracies
            # below so an empty prediction list reports 0 instead of raising
            # ZeroDivisionError.
            rte_global_accuracy = (
                rte_correct / rte_total * 100 if rte_total > 0 else 0.0
            )
            print(
                f"\nRound: {Round} Global Accuracy for Model on RTE dataset on aggregating {number_of_clients} clients: {rte_global_accuracy}%"
            )

            # Per-class accuracy in percent (0 when a class received no
            # well-formed prediction)
            class_accuracies_rte = {
                cls: (class_correct_rte[cls] / class_total_rte[cls] * 100 if class_total_rte[cls] > 0 else 0)
                for cls in class_total_rte
            }

            ds_name = 'rte'  # Dataset name

            # Print and append class-wise accuracies
            print("\nClass-wise Accuracies for RTE:")
            for cls, acc in class_accuracies_rte.items():
                print(f"Class {cls}: {acc:.2f}%")
                if ds_name in cl_acc_dict and str(cls) in cl_acc_dict[ds_name]:
                    cl_acc_dict[ds_name][str(cls)].append(acc)
                else:
                    print(f"Warning: Class {cls} not found in {ds_name} in cl_acc_dict")

            global_acc_rte.append(rte_global_accuracy)


            
            # Save the global model's parameters.
            # The file name depends only on number_of_clients, so a later save
            # with the same client count overwrites the earlier file.
            global_model_path = os.path.join(
                model_dir, f'global_model_{number_of_clients}_clients.pt')
            torch.save(global_params, global_model_path)

            # Disabled alternative that would save the full state dict instead:
            # torch.save({
            #     'model_state_dict': global_model.state_dict(),
            # }, global_model_path)

            # Remember where the most recent global parameters were written.
            prev_global_model_path = global_model_path

            # Drop this script's reference to the model, then ask the CUDA
            # caching allocator to release its unused cached blocks.
            del global_model
            torch.cuda.empty_cache()

            # Diagnostics only: True when no tensor memory is currently allocated
            # on the current CUDA device.
            is_cache_cleared = torch.cuda.memory_allocated() == 0
            # Despite its name this holds the stripped text of the allocator's
            # memory summary, not a boolean.
            # NOTE(review): neither flag is read in the visible part of the
            # script -- confirm whether they are still needed.
            is_memory_cleared = torch.cuda.memory_summary(device=None).strip()

            # Report the accuracy histories on stdout and append them, followed by
            # the per-task round counters, to the FedProx results file.
            path = f'{model_dir}/glue_FedProx.txt'

            # (console label, file label, accuracy history) per GLUE task, in
            # reporting order.
            _report_rows = (
                ('global_acc_cola=', 'cola dataset', global_acc_cola),
                ('global_acc_sst=', 'sst dataset', global_acc_sst),
                ('global_acc_mrpc=', 'mrpc dataset', global_acc_mrpc),
                ('global_acc_qqp=', 'qqp dataset', global_acc_qqp),
                ('global_acc_mnli=', 'mnli dataset', global_acc_mnli),
                ('global_acc_qnli=', 'qnli dataset', global_acc_qnli),
                ('global_acc_rte=', 'rte dataset', global_acc_rte),
            )
            for _console_label, _file_label, _acc_history in _report_rows:
                print(_console_label, _acc_history)

            # Rule written after every task section in the results file.
            _section_rule = '\n-------------------------------------------------------------------------------------------\n'

            with open(path, "a") as f:
                # One section per task: the label immediately followed by the
                # history list, a blank line, then the rule.
                for _console_label, _file_label, _acc_history in _report_rows:
                    f.write(_file_label)
                    f.write(str(_acc_history) + '\n')
                    f.write('\n')
                    f.write(_section_rule)

                # Round counters, one per task, each followed by a blank line.
                for _round_total in (
                    number_cola_rounds,
                    number_sst_rounds,
                    number_mrpc_rounds,
                    number_qqp_rounds,
                    number_mnli_rounds,
                    number_qnli_rounds,
                    number_rte_rounds,
                ):
                    f.write(str(_round_total) + '\n')
                    f.write('\n')

import json

# Persist the per-class accuracy history collected in cl_acc_dict.
# The data has always been written to class_accs.txt while the confirmation
# message named a different file; use one path for both so the message
# reports the file that is actually written.
file_path = "class_accs.txt"
d = {}
# NOTE(review): the content is the JSON encoding of the dict's Python repr
# (i.e. one JSON string), not a JSON object.  Kept unchanged so existing
# readers of the file are unaffected -- confirm before switching to
# json.dump(cl_acc_dict, cfile).
with open(file_path, 'w') as cfile:
    cfile.write(json.dumps(str(cl_acc_dict)))

print(f"\ncl_acc_dict has been saved to {file_path}.")

